昨天已經把位置編碼的演進介紹完了,需要考慮的點蠻多的。
參考來源:
https://www.cnblogs.com/rossiXYZ/p/18744797
https://medium.com/thedeephub/positional-encoding-explained-a-deep-dive-into-transformer-pe-65cfe8cfe10b
昨天有看過類似這張圖,這裡用底下這張圖講解。
當中的圖不像昨天二進制一樣只有0, 1,他是一個連續的,透過以下幾個觀點來了解:
觀念:
結論:
解決了昨天說的離散不連續的問題,值的範圍也有限,加上昨天說明過的,可以反應相對位置資訊。
這裡的實作只單做 positional encoding 這段,那整個是需要 token embedding 加起來才會得到最後的 word embedding。
這裡我們先實作 pe 的部分,步驟如下:
import torch
from torch import nn
# step1
class MyPositionEncoding(nn.Module):
def __init__(self):
super().__init__()
def forward(self, position_ids: torch.Tensor):
'''
B: batch size
L: seq len
position_ids: (B, L)
'''
pass
# step2
class MyPositionEncoding(nn.Module):
def __init__(self, max_seq_len, hidden_size):
super().__init__()
self.max_seq_len = max_seq_len
self.hidden_size = hidden_size
def forward(self, position_ids: torch.Tensor):
'''
B: batch size
L: seq len
position_ids: (B, L)
'''
pass
# step3 + step4
class MyPositionEncoding(nn.Module):
def __init__(self, max_seq_len, hidden_size):
super().__init__()
self.max_seq_len = max_seq_len
self.hidden_size = hidden_size
self.build_pos_enc()
def build_pos_enc(self):
# 初始化表格
pos_enc = torch.zeros(self.max_seq_len, self.hidden_size)
# 準備 position, shpae: L -> (L, 1) 用於等下相乘
position = torch.arange(0, self.max_seq_len).unsqueeze(1)
# inv 代表倒數的意思
# 因為兩個一組,所以維度0, 1 會用同一個,所以 arange 一次加 2
inv_freq = 1.0 / (10000 ** (torch.arange(0, self.hidden_size, 2).float() / self.hidden_size))
# print((torch.arange(0, self.hidden_size, 2).float() / self.hidden_size))
print(f'inv_freq: {inv_freq}')
# print(position * inv_freq)
# 偶數位使用 sin, 奇數位使用 cos → 放到 pos_enc 表格當中
# 將等號右邊的 sin 算完,放到左邊取出偶數位置的表格上
pos_enc[:, 0::2] = torch.sin(position * inv_freq)
print(f'已填入偶數位:\n {pos_enc}')
pos_enc[:, 1::2] = torch.cos(position * inv_freq)
print(f'再填入奇數位:\n {pos_enc}')
# 儲存起來
self.register_buffer('pos_enc', pos_enc)
def forward(self, position_ids: torch.Tensor):
'''
B: batch size
L: seq len
position_ids: (B, L)
'''
return self.pos_enc[position_ids]
# or
# return torch.embedding(self.pos_enc, position_ids)
測試程式
if __name__ == "__main__":
B, L, D = 2, 4, 6
x = torch.rand(B, L, D)
start_pos = 0
position_ids = torch.arange(
start = start_pos,
end = start_pos + L,
dtype = torch.long
).unsqueeze(0).expand(B, -1)
print(f'position_ids: {position_ids}')
pe = MyPositionEncoding(
max_seq_len = 10,
hidden_size = 6
)
y = pe(position_ids)
print(y.shape)
一樣可以照著步驟試著想想看做做看,不過是真的沒想到分步驟花的時間真的久,希望可以幫到你更好了解,明天我們先換換口味,今天先到這囉~